{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# CPU Operator Customization with Numba\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/operator_custom_with_numba.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/operator_custom_with_numba.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## English version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Brain dynamics is sparse and event-driven, however, proprietary operators for brain dynamics are not well abstracted and summarized. As a result, we are often faced with the need to customize operators. In this tutorial, we will explore how to customize brain dynamics operators using Numba.\n",
    "\n",
    "Start by importing the relevant Python package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-10T22:58:55.444792400Z",
     "start_time": "2023-10-10T22:58:55.368614800Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "\n",
    "import jax\n",
    "from jax import jit\n",
    "import jax.numpy as jnp\n",
    "from jax.core import ShapedArray\n",
    "\n",
    "import numba\n",
    "\n",
    "bm.set_platform('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### ``brainpy.math.CustomOpByNumba``\n",
    "\n",
    "BrainPy provides ``brainpy.math.CustomOpByNumba`` for customizing the operator on the CPU device. Two parameters are required to provide in ``CustomOpByNumba``:\n",
    "\n",
    "- ``eval_shape``: evaluates the *shape* and *datatype* of the output argument based on the *shape* and *datatype* of the input argument.\n",
    "- `con_compute`: receives the input parameters and performs a specific computation based on them.\n",
    "\n",
    "Suppose here we want to customize an operator that does the ``b = a+1`` operation. First, define an ``eval_shape`` function. The arguments to this function are information about all the input parameters, and the return value is information about the output parameters.\n",
    "\n",
    "```python\n",
    "from jax.core import ShapedArray\n",
    "\n",
    "def eval_shape(a):\n",
    "  b = ShapedArray(a.shape, dtype=a.dtype)\n",
    "  return b\n",
    "```\n",
    "\n",
    "Since ``b`` in ``b = a + 1`` has the same type and shape as ``a``, the ``eval_shape`` function returns the same shape and type. Next, we need to define ``con_compute``. ``con_compute`` takes only ``(outs, ins)`` arguments, where all return values are inside ``outs`` and all input arguments are inside ``ins``.\n",
    "\n",
    "\n",
    "```python\n",
    "def con_compute(outs, ins):\n",
    "  b = outs\n",
    "  a = ins\n",
    "  b[:] = a + 1\n",
    "```\n",
    "\n",
    "Unlike the ``eval_shape`` function, the ``con_compute`` function does not support any return values. Instead, all output must just be updated in-place. Also, the ``con_compute`` function must follow the specification of Numba's just-in-time compilation, see:\n",
    "\n",
    "- https://numba.pydata.org/numba-doc/latest/reference/pysupported.html\n",
    "- https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html\n",
    "\n",
    "Also, ``con_compute`` can be customized according to Numba's just-in-time compilation policy. For example, if JIT is just turned on, then you can use:\n",
    "\n",
    "```python\n",
    "@numba.njit\n",
    "def con_compute(outs, ins):\n",
    "  b = outs\n",
    "  a = ins\n",
    "  b[:] = a + 1\n",
    "```\n",
    "\n",
    "If the parallel computation with multiple cores is turned on, you can use:\n",
    "\n",
    "\n",
    "```python\n",
    "@numba.njit(parallel=True)\n",
    "def con_compute(outs, ins):\n",
    "  b = outs\n",
    "  a = ins\n",
    "  b[:] = a + 1\n",
    "```\n",
    "\n",
    "\n",
    "For more advanced usage, we encourage readers to read the [Numba online manual](https://numba.pydata.org/numba-doc/latest/index.html).\n",
    "\n",
    "Finally, this customized operator can be registered and used as:\n",
    "\n",
    "```bash\n",
    "\n",
    ">>> op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)\n",
    ">>> op(bm.zeros(10))\n",
    "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#### Return multiple values ``multiple_returns=True``\n",
    "\n",
    "If the result of our computation needs to return multiple arrays, then we need to use ``multiple_returns=True`` in our use of registering the operator. In this case, ``outs`` will be a list containing multiple arrays, not an array.\n",
    "\n",
    "\n",
    "```python\n",
    "def eval_shape2(a, b):\n",
    "  c = ShapedArray(a.shape, dtype=a.dtype)\n",
    "  d = ShapedArray(b.shape, dtype=b.dtype)\n",
    "  return c, d\n",
    "\n",
    "def con_compute2(outs, ins):\n",
    "  c = outs[0]  # take out all the outputs\n",
    "  d = outs[1]\n",
    "  a = ins[0]  # take out all the inputs\n",
    "  b = ins[1]\n",
    "  c[:] = a + 1\n",
    "  d[:] = a * 2\n",
    "\n",
    "op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)\n",
    "```\n",
    "\n",
    "```bash\n",
    ">>> op2(bm.zeros(10), bm.ones(10))\n",
    "([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.],\n",
    " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#### Non-Tracer parameters\n",
    "\n",
    "In the ``eval_shape`` function, all arguments are abstract information (containing only the shape and type) if they are arguments that can be traced by ``jax.jit``. However, if we infer the output data type requires additional information beyond the input parameter information, then we need to define non-Tracer parameters.\n",
    "\n",
    "For an operator defined by ``brainpy.math.CustomOpByNumba``, non-Tracer parameters are often then parameters passed in via key-value pairs such as ``key=value``. For example:\n",
    "\n",
    "```python\n",
    "op2(a, b, c, d=d, e=e)\n",
    "```\n",
    "\n",
    "``a, b, c`` are all ``jax.jit`` traceable parameters, and ``d`` and ``e`` are deterministic, non-tracer parameters. Therefore, in the ``eval_shape(a, b, c, d, e)`` function, ``a, b, c`` will be ``SharedArray``, and ``d`` and ``e`` will be concrete values.\n",
    "\n",
    "For another example, \n",
    "\n",
    "```python\n",
    "\n",
    "def eval_shape3(a, *, b):\n",
    "  return SharedArray(b, a.dtype)  # The shape of the return value is determined by the input b\n",
    "\n",
    "def con_compute3(outs, ins):\n",
    "  c = outs  # Take out all the outputs\n",
    "  a = ins[0] # Take out all inputs\n",
    "  b = ins[1]\n",
    "  c[:] = 2.\n",
    "\n",
    "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
    "```\n",
    "\n",
    "```bash\n",
    ">>> op3(bm.zeros(4), 5)\n",
    "[2. 2. 2. 2. 2.]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "... note:\n",
    "\n",
    "    It is worth noting that all arguments will be converted to arrays. Both Tracer and non-Tracer parameters are arrays in ``con_compute``. For example, ``1`` is passed in, but in ``con_compute`` it's a 0-dimensional array ``1``; ``(1, 2)`` is passed in, and in ``con_compute`` it will be the 1-dimensional array ``array([1, 2])``.\n",
    " "
   ]
  },
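  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The NumPy sketch below illustrates this conversion and shows how to read the scalar back out of a 0-dimensional array by indexing with an empty tuple. It uses ``np.asarray`` as a stand-in for whatever conversion BrainPy performs internally, which is an assumption.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "scalar_arg = np.asarray(1)      # what a plain 1 becomes\n",
    "tuple_arg = np.asarray((1, 2))  # what (1, 2) becomes\n",
    "\n",
    "print(scalar_arg.ndim, scalar_arg.shape)  # 0-dimensional, empty shape\n",
    "print(tuple_arg.ndim, tuple_arg.shape)    # 1-dimensional, two elements\n",
    "\n",
    "value = scalar_arg[()]  # index with an empty tuple to extract the scalar\n",
    "print(value + 1)\n",
    "```"
   ]
  },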
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#### Example: A sparse operator\n",
    "\n",
    "To illustrate the effectiveness of this approach, we define in this an event-driven sparse computation operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-10T22:58:55.539425400Z",
     "start_time": "2023-10-10T22:58:55.398947400Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def abs_eval(data, indices, indptr, vector, shape):\n",
    "  out_shape = shape[0]\n",
    "  return ShapedArray((out_shape,), data.dtype),\n",
    "\n",
    "@numba.njit(fastmath=True)\n",
    "def sparse_op(outs, ins):\n",
    "  res_val = outs[0]\n",
    "  res_val.fill(0)\n",
    "  values, col_indices, row_ptr, vector, shape = ins\n",
    "\n",
    "  for row_i in range(shape[0]):\n",
    "      v = vector[row_i]\n",
    "      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):\n",
    "          res_val[col_indices[j]] += values * v\n",
    "\n",
    "sparse_cus_op = bm.CustomOpByNumba(eval_shape=abs_eval, con_compute=sparse_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Let's try to use sparse matrix vector multiplication operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-10T22:58:57.856525300Z",
     "start_time": "2023-10-10T22:58:55.414106700Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Array([ -2.2834747, -52.950108 ,  -5.0921535, ..., -40.264236 ,\n",
       "        -27.219269 ,  33.138054 ], dtype=float32)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "size = 5000\n",
    "\n",
    "vector = bm.random.randn(size)\n",
    "sparse_A = bp.conn.FixedProb(prob=0.1, allow_multi_conn=True)(size, size).require('pre2post')\n",
    "f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n",
    "f(1.)"
   ]
  },
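  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "To make the loop in ``sparse_op`` concrete, the NumPy sketch below runs the same accumulation on a tiny, hand-written CSR structure and compares it with the dense product ``A.T @ vector``. The matrix, the scalar weight, and all variable names are made up for illustration, and the loop is written in plain Python rather than compiled with Numba.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# A 3x4 connectivity pattern in CSR form (rows = pre, columns = post).\n",
    "col_indices = np.array([0, 2, 1, 3, 0])\n",
    "row_ptr = np.array([0, 2, 4, 5])\n",
    "weight = 0.5\n",
    "vector = np.array([1.0, 2.0, 3.0])\n",
    "n_rows, n_cols = 3, 4\n",
    "\n",
    "# Same accumulation as sparse_op, written as a plain Python loop.\n",
    "res = np.zeros(n_cols)\n",
    "for row_i in range(n_rows):\n",
    "  for j in range(row_ptr[row_i], row_ptr[row_i + 1]):\n",
    "    res[col_indices[j]] += weight * vector[row_i]\n",
    "\n",
    "# Dense reference: build A explicitly and multiply.\n",
    "dense = np.zeros((n_rows, n_cols))\n",
    "for row_i in range(n_rows):\n",
    "  dense[row_i, col_indices[row_ptr[row_i]:row_ptr[row_i + 1]]] = weight\n",
    "expected = dense.T @ vector\n",
    "\n",
    "print(res)\n",
    "print(np.allclose(res, expected))\n",
    "```"
   ]
  },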
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### brainpy.math.XLACustomOp\n",
    "\n",
    "`brainpy.math.XLACustomOp` is a new method for customizing operators on the CPU device. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. If you want to use this new method with numba, you only need to define a kernel using @numba.jit or @numba.njit, and then pass the kernel to `brainpy.math.XLACustomOp`.\n",
    "\n",
    "Detailed steps are as follows:\n",
    "\n",
    "#### Define the kernel\n",
    "\n",
    "```python\n",
    "@numba.njit(fastmath=True)\n",
    "def numba_event_csrmv(weight, indices, vector, outs):\n",
    "  outs.fill(0)\n",
    "  weight = weight[()]  # 0d\n",
    "  for row_i in range(vector.shape[0]):\n",
    "    if vector[row_i]:\n",
    "      for j in indices[row_i]:\n",
    "        outs[j] += weight\n",
    "```\n",
    "\n",
    "In the declaration of parameters, the last few parameters need to be output parameters so that numba can compile correctly. This operator numba_event_csrmv receives four parameters: `weight`, `indices`, `vector`, and `outs`. The first three parameters are input parameters, and the last parameter is the output parameter. The output parameter is a 1D array, and the input parameters are 0D, 1D, and 2D arrays, respectively."
   ]
  },
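  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the kernel logic, the sketch below runs the same function body without the Numba decorator on a tiny hand-made input. The name ``event_csrmv_reference`` and the input values are illustrative assumptions, not part of BrainPy.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def event_csrmv_reference(weight, indices, vector, outs):\n",
    "  outs.fill(0)\n",
    "  weight = weight[()]  # extract the scalar from the 0D array\n",
    "  for row_i in range(vector.shape[0]):\n",
    "    if vector[row_i]:  # only rows with an event contribute\n",
    "      for j in indices[row_i]:\n",
    "        outs[j] += weight\n",
    "\n",
    "weight = np.asarray(1., dtype=np.float32)     # 0D input\n",
    "indices = np.array([[0, 1], [1, 2], [1, 3]])  # 2D input: two targets per row\n",
    "vector = np.array([True, False, True])        # 1D event vector\n",
    "outs = np.empty(4, dtype=np.float32)          # 1D output buffer\n",
    "\n",
    "event_csrmv_reference(weight, indices, vector, outs)\n",
    "print(outs)\n",
    "```"
   ]
  },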
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Registering and Using Custom Operators\n",
    "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n",
    "\n",
    "Note: Maintain the order of the operator's declared parameters consistent with the order when calling.\n",
    "\n",
    "```python\n",
    "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
    "indices = bm.random.randint(0, s, (s, 80))\n",
    "vector = bm.random.rand(s) < 0.1\n",
    "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n",
    "print(out)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 中文版"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "大脑动力学具有稀疏和事件驱动的特性，然而，大脑动力学的专有算子并没有很好的抽象和总结。因此，我们往往面临着自定义算子的需求。在这个教程中，我们将探索如何使用Numba来自定义脑动力学算子。\n",
    "\n",
    "首先引入相关的Python包。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-10T22:58:57.858443100Z",
     "start_time": "2023-10-10T22:58:57.842107200Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "\n",
    "import jax\n",
    "from jax import jit\n",
    "import jax.numpy as jnp\n",
    "from jax.core import ShapedArray\n",
    "\n",
    "import numba\n",
    "\n",
    "bm.set_platform('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### ``brainpy.math.CustomOpByNumba``接口\n",
    "\n",
    "``brainpy.math.CustomOpByNumba`` 也叫做``brainpy.math.XLACustomOp``。\n",
    "\n",
    "BrainPy提供了``brainpy.math.CustomOpByNumba``用于自定义CPU上的算子。使用``CustomOpByNumba``需要提供两个接口：\n",
    "\n",
    "- `eval_shape`: 根据输入参数的形状(shape)和数据类型(dtype)来评估输出参数的形状和数据类型。\n",
    "- `con_compute`: 接收真正的参数，并根据参数进行具体计算。\n",
    "\n",
    "假如在这里我们要自定义一个做``b = a+1``操作的算子。首先，定义一个``eval_shape``函数。该函数的参数是所有输入变量的信息，返回值是输出参数的信息。\n",
    "\n",
    "```python\n",
    "from jax.core import ShapedArray\n",
    "\n",
    "def eval_shape(a):\n",
    "  b = ShapedArray(a.shape, dtype=a.dtype)\n",
    "  return b\n",
    "```\n",
    "\n",
    "由于``b = a + 1``中``b``与``a``具有同样的类型和形状，因此``eval_shape``函数返回一样的形状和类型。接下来，我们就需要定义``con_compute``。``con_compute``只接收``(outs, ins)``参数，其中，所有的返回值都在``outs``内，所有的输入参数都在``ins``内。\n",
    "\n",
    "\n",
    "```python\n",
    "\n",
    "def con_compute(outs, ins):\n",
    "  b = outs\n",
    "  a = ins\n",
    "  b[:] = a + 1\n",
    "```\n",
    "\n",
    "与``eval_shape``函数不同，``con_compute``函数不接收任何返回值。相反，所有的输出都必须通过in-place update的形式就行。另外，``con_compute``函数必须遵循Numba即时编译的规范，见：\n",
    "\n",
    "- https://numba.pydata.org/numba-doc/latest/reference/pysupported.html\n",
    "- https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html\n",
    "\n",
    "同时，``con_compute``也可以自定义Numba的即时编译策略。比如，如果只是开启JIT，那么可以用：\n",
    "\n",
    "```python\n",
    "@numba.njit\n",
    "def con_compute(outs, ins):\n",
    "  b = outs\n",
    "  a = ins\n",
    "  b[:] = a + 1\n",
    "```\n",
    "\n",
    "如果是开始并行计算利用多核，可以使用：\n",
    "\n",
    "\n",
    "```python\n",
    "@numba.njit(parallel=True)\n",
    "def con_compute(outs, ins):\n",
    "  b = outs\n",
    "  a = ins\n",
    "  b[:] = a + 1\n",
    "```\n",
    "\n",
    "\n",
    "更多高级用法，建议读者们阅读[Numba在线手册](https://numba.pydata.org/numba-doc/latest/index.html)。\n",
    "\n",
    "最后，我们自定义这个算子可以使用：\n",
    "\n",
    "```bash\n",
    "\n",
    ">>> op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)\n",
    ">>> op(bm.zeros(10))\n",
    "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#### 返回多个值 ``multiple_returns=True``\n",
    "\n",
    "如果我们的计算结果需要返回多个数组，那么，我们在注册算子的使用需要使用``multiple_returns=True``。此时，``outs``将会是一个包含多个数组的列表，而不是一个数组。\n",
    "\n",
    "```python\n",
    "def eval_shape2(a, b):\n",
    "  c = ShapedArray(a.shape, dtype=a.dtype)\n",
    "  d = ShapedArray(b.shape, dtype=b.dtype)\n",
    "  return c, d  # 返回多个抽象数组信息\n",
    "\n",
    "def con_compute2(outs, ins):\n",
    "  c = outs[0]  # 取出所有的输出\n",
    "  d = outs[1]\n",
    "  a = ins[0]  # 取出所有的输入\n",
    "  b = ins[1]\n",
    "  c[:] = a + 1\n",
    "  d[:] = a * 2\n",
    "\n",
    "op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)\n",
    "```\n",
    "\n",
    "```bash\n",
    ">>> op2(bm.zeros(10), bm.ones(10))\n",
    "([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.],\n",
    " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#### 非Tracer参数\n",
    "\n",
    "在``eval_shape``函数中推断数据类型时，如果所有参数都是可以被``jax.jit``追踪的参数，那么所有参数都是抽象信息（只包含形状和类型）。如果有时推断输出数据类型时还需要除输入参数信息以外的额外信息，此时我们需要定义非Tracer参数。\n",
    "\n",
    "对于一个由``brainpy.math.CustomOpByNumba``定义的算子，非Tracer参数往往那么通过``key=value``等键值对传入的参数。比如，\n",
    "\n",
    "```python\n",
    "op2(a, b, c, d=d, e=e)\n",
    "```\n",
    "\n",
    "``a, b, c``都是可被`jax.jit`追踪的参数，`d`和`e`是确定性的、非Tracer参数。此时，``eval_shape(a, b, c, d, e)``函数中，a，b，c都是``SharedArray``，而d和e都是具体的数值，\n",
    "\n",
    "举个例子，\n",
    "\n",
    "```python\n",
    "\n",
    "def eval_shape3(a, *, b):\n",
    "  return SharedArray(b, a.dtype)  # 返回值的形状由输入b决定\n",
    "\n",
    "def con_compute3(outs, ins):\n",
    "  c = outs  # 取出所有的输出\n",
    "  a = ins[0] # 取出所有的输入\n",
    "  b = ins[1]\n",
    "  c[:] = 2.\n",
    "\n",
    "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
    "```\n",
    "\n",
    "```bash\n",
    ">>> op3(bm.zeros(4), 5)\n",
    "[2. 2. 2. 2. 2.]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "\n",
    "... note::\n",
    "\n",
    "    值得注意的是，所有的输入值都将被转化成数组。无论是Tracer还是非Tracer参数，在``con_compute``中都是数组。比如传入的是``1``，但在``con_compute``中是0维数组``1``；传入的是``(1, 2)``，在``con_compute``中将是1维数组``array([1, 2])``。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#### 示例：一个稀疏算子\n",
    "\n",
    "为了说明这种方法的有效性，我们在这个定义一个事件驱动的稀疏计算算子。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-10T22:58:57.858443100Z",
     "start_time": "2023-10-10T22:58:57.849184700Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def abs_eval(data, indices, indptr, vector, shape):\n",
    "  out_shape = shape[0]\n",
    "  return [ShapedArray((out_shape,), data.dtype)]\n",
    "\n",
    "@numba.njit(fastmath=True)\n",
    "def sparse_op(outs, ins):\n",
    "  res_val = outs[0]\n",
    "  res_val.fill(0)\n",
    "  values, col_indices, row_ptr, vector, shape = ins\n",
    "\n",
    "  for row_i in range(shape[0]):\n",
    "      v = vector[row_i]\n",
    "      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):\n",
    "          res_val[col_indices[j]] += values * v\n",
    "\n",
    "sparse_cus_op = bm.CustomOpByNumba(eval_shape=abs_eval, con_compute=sparse_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "使用该算子我们可以用："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-10T22:58:58.245683200Z",
     "start_time": "2023-10-10T22:58:57.853019500Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Array([ 17.464092,  -9.924386, -33.09052 , ..., -37.2057  , -12.551924,\n",
       "         -9.046049], dtype=float32)]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "size = 5000\n",
    "\n",
    "vector = bm.random.randn(size)\n",
    "sparse_A = bp.conn.FixedProb(prob=0.1, allow_multi_conn=True)(size, size).require('pre2post')\n",
    "f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n",
    "f(1.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### brainpy.math.XLACustomOp\n",
    "\n",
    "`brainpy.math.XLACustomOp` is a new method for customizing operators on the CPU device. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. If you want to use this new method with numba, you only need to define a kernel using `@numba.jit` or `@numba.njit` decorator, and then pass the kernel to `brainpy.math.XLACustomOp`.\n",
    "`brainpy.math.XLACustomOp`是一种自定义算子的新方法。它类似于`brainpy.math.CustomOpByNumba`，但它更灵活并支持更高级的特性。如果您想用numba使用这种新方法，只需要使用 `@numba.jit`或`@numba.njit`装饰器定义一个kernel，然后将内核传递给`brainpy.math.XLACustomOp`。\n",
    "\n",
    "详细步骤如下：\n",
    "\n",
    "#### 定义kernel\n",
    "在参数声明中，最后几个参数需要是输出参数，这样numba才能正确编译。这个算子`numba_event_csrmv`接受四个参数：weight、indices、vector 和 outs。前三个参数是输入参数，最后一个参数是输出参数。输出参数是一个一维数组，输入参数分别是 0D、1D 和 2D 数组。\n",
    "\n",
    "```python\n",
    "@numba.njit(fastmath=True)\n",
    "def numba_event_csrmv(weight, indices, vector, outs):\n",
    "  outs.fill(0)\n",
    "  weight = weight[()]  # 0d\n",
    "  for row_i in range(vector.shape[0]):\n",
    "    if vector[row_i]:\n",
    "      for j in indices[row_i]:\n",
    "        outs[j] += weight\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 注册并使用自定义算子\n",
    "在定义了自定义算子之后，可以将其注册到特定框架中，并在需要的地方使用它。在注册时可以指定`cpu_kernel`和`gpu_kernel`，这样算子就可以在不同的设备上运行。并在调用中指定`outs`参数，用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。\n",
    "\n",
    "注意： 在算子声明的参数与调用时需要保持顺序的一致。\n",
    "\n",
    "```python\n",
    "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
    "indices = bm.random.randint(0, s, (s, 80))\n",
    "vector = bm.random.rand(s) < 0.1\n",
    "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n",
    "print(out)\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
